{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a682ea0b",
   "metadata": {},
   "source": [
    "# BentoML PyTorch MNIST Tutorial\n",
    "\n",
    "Link to source code: https://github.com/bentoml/BentoML/tree/main/examples/pytorch_mnist/\n",
    "\n",
    "Install required dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ad00863",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install this example's dependencies into the active kernel's environment.\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45393b74",
   "metadata": {},
   "source": [
    "## Define the model\n",
    "\n",
    "First, let's define a simple convolutional network in PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caeff07d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class SimpleConvNet(nn.Module):\n",
    "    \"\"\"Small convolutional network for 28x28 single-channel (MNIST) digits.\n",
    "\n",
    "    Expects batched input of shape (N, 1, 28, 28): the 3x3 convolution\n",
    "    (no padding) yields 10 feature maps of 26x26, which ``nn.Flatten``\n",
    "    (start_dim=1) turns into 26 * 26 * 10 features per sample. The last\n",
    "    layer produces raw logits of shape (N, 10).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.layers = nn.Sequential(\n",
    "            nn.Conv2d(1, 10, kernel_size=3),\n",
    "            nn.ReLU(),\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(26 * 26 * 10, 50),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(50, 20),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(20, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return class logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28).\"\"\"\n",
    "        return self.layers(x)\n",
    "\n",
    "    def predict(self, inp):\n",
    "        \"\"\"Predict the digit for each sample in the batch ``inp``.\n",
    "\n",
    "        Runs the forward pass in eval mode with autograd disabled and returns\n",
    "        a tensor of shape (N,) holding the index of the largest logit per\n",
    "        sample. The module's previous train/eval mode is restored afterwards,\n",
    "        so calling ``predict`` does not silently change how later training\n",
    "        steps behave.\n",
    "        \"\"\"\n",
    "        was_training = self.training\n",
    "        self.eval()\n",
    "        try:\n",
    "            with torch.no_grad():\n",
    "                return self(inp).argmax(dim=1)\n",
    "        finally:\n",
    "            self.train(was_training)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38888f0a",
   "metadata": {},
   "source": [
    "## Training and Saving the model\n",
    "\n",
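    "Next, we import the training dependencies, seed the random number generators for reproducibility, and define some helper functions for loading data, training, and evaluation."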
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c62db15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "from torchvision.datasets import MNIST\n",
    "from torch.utils.data import DataLoader, ConcatDataset\n",
    "from torchvision import transforms\n",
    "from sklearn.model_selection import KFold\n",
    "\n",
    "import bentoml\n",
    "\n",
    "# Reproducible setup for testing: seed the Python, NumPy and PyTorch\n",
    "# (CPU and CUDA) random number generators with one shared value.\n",
    "seed = 42\n",
    "for _seed_rng in (\n",
    "    random.seed,\n",
    "    np.random.seed,\n",
    "    torch.manual_seed,\n",
    "    torch.cuda.manual_seed,\n",
    "    torch.cuda.manual_seed_all,\n",
    "):\n",
    "    _seed_rng(seed)\n",
    "\n",
    "# Disable cuDNN autotuning and request deterministic kernels so GPU runs repeat.\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "\n",
    "def _dataloader_init_fn(worker_id):\n",
    "    \"\"\"Seed NumPy and ``random`` inside a DataLoader worker process.\n",
    "\n",
    "    Follows the PyTorch reproducibility recipe: inside a worker,\n",
    "    ``torch.initial_seed()`` is the per-worker seed PyTorch assigns (derived\n",
    "    from the seeded main-process RNG plus the worker id), so workers get\n",
    "    distinct yet reproducible streams instead of all sharing one constant.\n",
    "    Only called when a DataLoader uses ``num_workers > 0``.\n",
    "    \"\"\"\n",
    "    worker_seed = torch.initial_seed() % 2**32  # NumPy seeds must fit in 32 bits\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "539b5097",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default number of training epochs (per fold and for the final model).\n",
    "NUM_EPOCHS = 5\n",
    "# Default number of folds used by cross_validate.\n",
    "K_FOLDS = 5\n",
    "# NOTE(review): cross_validate and train each build their own\n",
    "# nn.CrossEntropyLoss(); this shared instance is not referenced by them.\n",
    "LOSS_FUNCTION = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "def get_dataset(root=None):\n",
    "    \"\"\"Load the MNIST train and test splits, downloading them if needed.\n",
    "\n",
    "    The two splits are returned separately (they are not concatenated); the\n",
    "    training split is divided into folds later by ``cross_validate``.\n",
    "\n",
    "    Args:\n",
    "        root: directory the MNIST files are read from / downloaded to.\n",
    "            Defaults to the current working directory.\n",
    "\n",
    "    Returns:\n",
    "        ``(train_set, test_set)`` as ``torchvision.datasets.MNIST`` instances\n",
    "        whose images are converted with ``transforms.ToTensor()``.\n",
    "    \"\"\"\n",
    "    if root is None:\n",
    "        root = os.getcwd()\n",
    "    train_set = MNIST(root, download=True, transform=transforms.ToTensor(), train=True)\n",
    "    test_set = MNIST(root, download=True, transform=transforms.ToTensor(), train=False)\n",
    "    return train_set, test_set\n",
    "\n",
    "\n",
    "def train_epoch(\n",
    "    model, optimizer, loss_function, train_loader, epoch, device=\"cpu\", log_interval=499\n",
    "):\n",
    "    \"\"\"Run one training epoch over ``train_loader``, updating ``model`` in place.\n",
    "\n",
    "    Args:\n",
    "        model: module to train; assumed to already live on ``device``.\n",
    "        optimizer: optimizer over ``model``'s parameters.\n",
    "        loss_function: callable ``(outputs, targets) -> scalar loss``.\n",
    "        train_loader: DataLoader yielding ``(inputs, targets)`` batches.\n",
    "        epoch: epoch index, used only in the progress log.\n",
    "        device: device each batch is moved to before the forward pass.\n",
    "        log_interval: print progress every ``log_interval`` batches.\n",
    "    \"\"\"\n",
    "    model.train()  # enable training-mode behaviour\n",
    "    # Count the samples the sampler actually yields: with a SubsetRandomSampler\n",
    "    # (as in cross-validation) this is the fold's training size, not the size\n",
    "    # of the whole underlying dataset.\n",
    "    num_samples = len(train_loader.sampler)\n",
    "    num_batches = len(train_loader)\n",
    "    for batch_idx, (inputs, targets) in enumerate(train_loader):\n",
    "        inputs, targets = inputs.to(device), targets.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(inputs)\n",
    "        loss = loss_function(outputs, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if batch_idx % log_interval == 0:\n",
    "            print(\n",
    "                \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\".format(\n",
    "                    epoch,\n",
    "                    batch_idx * len(inputs),\n",
    "                    num_samples,\n",
    "                    100.0 * batch_idx / num_batches,\n",
    "                    loss.item(),\n",
    "                )\n",
    "            )\n",
    "\n",
    "\n",
    "def test_model(model, test_loader, device=\"cpu\"):\n",
    "    \"\"\"Count correct top-1 predictions of ``model`` over ``test_loader``.\n",
    "\n",
    "    Puts ``model`` in eval mode (and leaves it there) and disables autograd\n",
    "    while evaluating. ``model`` is assumed to already live on ``device``.\n",
    "\n",
    "    Returns:\n",
    "        ``(correct, total)``: the number of correctly classified samples and\n",
    "        the number of samples seen. ``total`` is 0 for an empty loader, so\n",
    "        callers dividing by it should guard against that.\n",
    "    \"\"\"\n",
    "    correct, total = 0, 0\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in test_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "            predicted = model(inputs).argmax(dim=1)\n",
    "            total += targets.size(0)\n",
    "            correct += (predicted == targets).sum().item()\n",
    "\n",
    "    return correct, total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d2db4a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load both MNIST splits. Only the test split gets a loader here; the\n",
    "# training split is consumed by the cross-validation and training cells.\n",
    "train_set, test_set = get_dataset()\n",
    "test_sampler = torch.utils.data.RandomSampler(test_set)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_set,\n",
    "    sampler=test_sampler,\n",
    "    batch_size=10,\n",
    "    worker_init_fn=_dataloader_init_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "788c19a0",
   "metadata": {},
   "source": [
    "### Cross Validation\n",
    "\n",
    "We run K-fold cross validation on the training split; the per-fold accuracies are later saved with the model as metadata.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b2fdd72",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cross_validate(dataset, epochs=NUM_EPOCHS, k_folds=K_FOLDS):\n",
    "    \"\"\"Estimate accuracy on ``dataset`` with shuffled K-fold cross validation.\n",
    "\n",
    "    For each fold, a fresh ``SimpleConvNet`` is trained for ``epochs`` epochs on\n",
    "    the other folds and then evaluated on the held-out fold.\n",
    "\n",
    "    Args:\n",
    "        dataset: map-style dataset yielding ``(image, label)`` pairs.\n",
    "        epochs: number of training epochs per fold.\n",
    "        k_folds: number of folds; ``KFold`` requires at least 2.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping fold index to held-out accuracy in percent (0-100).\n",
    "    \"\"\"\n",
    "    results = {}\n",
    "\n",
    "    # Shuffled K-fold splitter. With no random_state it draws from NumPy's\n",
    "    # global RNG, which the setup cell seeds.\n",
    "    kfold = KFold(n_splits=k_folds, shuffle=True)\n",
    "\n",
    "    print(\"--------------------------------\")\n",
    "\n",
    "    for fold, (train_ids, test_ids) in enumerate(kfold.split(dataset)):\n",
    "        print(f\"FOLD {fold}\")\n",
    "        print(\"--------------------------------\")\n",
    "\n",
    "        # Sample elements randomly from a given list of ids, no replacement.\n",
    "        train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)\n",
    "        test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)\n",
    "\n",
    "        # Data loaders restricted to this fold's training / held-out indices\n",
    "        train_loader = torch.utils.data.DataLoader(\n",
    "            dataset,\n",
    "            batch_size=10,\n",
    "            sampler=train_subsampler,\n",
    "            worker_init_fn=_dataloader_init_fn,\n",
    "        )\n",
    "        test_loader = torch.utils.data.DataLoader(\n",
    "            dataset,\n",
    "            batch_size=10,\n",
    "            sampler=test_subsampler,\n",
    "            worker_init_fn=_dataloader_init_fn,\n",
    "        )\n",
    "\n",
    "        # Train a fresh model on this fold\n",
    "        model = SimpleConvNet()\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "        loss_function = nn.CrossEntropyLoss()\n",
    "        for epoch in range(epochs):\n",
    "            train_epoch(model, optimizer, loss_function, train_loader, epoch)\n",
    "\n",
    "        # Evaluate on the held-out fold\n",
    "        correct, total = test_model(model, test_loader)\n",
    "        print(\"Accuracy for fold %d: %d %%\" % (fold, 100.0 * correct / total))\n",
    "        print(\"--------------------------------\")\n",
    "        results[fold] = 100.0 * (correct / total)\n",
    "\n",
    "    # Report the number of folds actually used (the k_folds argument),\n",
    "    # not the module-level default.\n",
    "    print(f\"K-FOLD CROSS VALIDATION RESULTS FOR {k_folds} FOLDS\")\n",
    "    print(\"--------------------------------\")\n",
    "    for fold, accuracy in results.items():\n",
    "        print(f\"Fold {fold}: {accuracy} %\")\n",
    "\n",
    "    # Mean of the per-fold accuracies, using the builtin sum.\n",
    "    print(f\"Average: {sum(results.values()) / len(results)} %\")\n",
    "\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd06de8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One epoch per fold keeps the cross-validation demo quick.\n",
    "cv_results = cross_validate(dataset=train_set, epochs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad2104a6",
   "metadata": {},
   "source": [
    "### Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3d311c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataset, epochs=NUM_EPOCHS, device=\"cpu\"):\n",
    "    \"\"\"Train a fresh ``SimpleConvNet`` on all of ``dataset``.\n",
    "\n",
    "    Args:\n",
    "        dataset: map-style dataset yielding ``(image, label)`` pairs.\n",
    "        epochs: number of passes over ``dataset``.\n",
    "        device: device that the model and every batch are placed on.\n",
    "\n",
    "    Returns:\n",
    "        The trained model, located on ``device``.\n",
    "    \"\"\"\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        dataset,\n",
    "        batch_size=10,\n",
    "        sampler=torch.utils.data.RandomSampler(dataset),\n",
    "        worker_init_fn=_dataloader_init_fn,\n",
    "    )\n",
    "    # The model must live on the same device that train_epoch moves each batch\n",
    "    # to; otherwise any non-CPU device fails on the first forward pass.\n",
    "    model = SimpleConvNet().to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "    loss_function = nn.CrossEntropyLoss()\n",
    "    for epoch in range(epochs):\n",
    "        train_epoch(model, optimizer, loss_function, train_loader, epoch, device)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8df05c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the final model on the full MNIST training split.\n",
    "trained_model = train(dataset=train_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92d9b23c",
   "metadata": {},
   "source": [
    "### Saving the model with metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fe9c4a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct, total = test_model(trained_model, test_loader)\n",
    "metadata = {\n",
    "    # Fraction (0-1) of the MNIST test split classified correctly.\n",
    "    \"accuracy\": correct / total,\n",
    "    # Per-fold cross-validation accuracies in percent (0-100). Fold indices are\n",
    "    # stringified so the metadata only uses string keys, which JSON/YAML\n",
    "    # serialization preserves as-is.\n",
    "    \"cv_stats\": {str(fold): acc for fold, acc in cv_results.items()},\n",
    "}\n",
    "\n",
    "tag = bentoml.pytorch.save(\n",
    "    \"pytorch_mnist\",\n",
    "    trained_model,\n",
    "    metadata=metadata,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdf35e55",
   "metadata": {},
   "source": [
    "## Create a BentoML Service for serving the model\n",
    "\n",
    "Note: we use `%%writefile` here because the `bentoml.Service` instance must be defined in a separate `.py` file.\n",
    "\n",
    "Even though we have only one model, we can create as many API endpoints as we want. Here we create two endpoints: `predict_ndarray` and `predict_image`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e2f590",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile service.py\n",
    "\n",
    "import typing as t\n",
    "\n",
    "import numpy as np\n",
    "import PIL.Image\n",
    "from PIL.Image import Image as PILImage\n",
    "\n",
    "import bentoml\n",
    "from bentoml.io import Image\n",
    "from bentoml.io import NumpyNdarray\n",
    "\n",
    "\n",
    "mnist_runner = bentoml.pytorch.get(\n",
    "    \"pytorch_mnist\",\n",
    "    name=\"mnist_runner\",\n",
    "    predict_fn_name=\"predict\",\n",
    ").to_runner()\n",
    "\n",
    "svc = bentoml.Service(\n",
    "    name=\"pytorch_mnist_demo\",\n",
    "    runners=[\n",
    "        mnist_runner,\n",
    "    ],\n",
    ")\n",
    "\n",
    "\n",
    "def _to_model_batch(\n",
    "    arr: \"np.ndarray[t.Any, np.dtype[t.Any]]\",\n",
    ") -> \"np.ndarray[t.Any, np.dtype[t.Any]]\":\n",
    "    \"\"\"Wrap one 28x28 greyscale image as a float32 batch of shape (1, 1, 28, 28).\n",
    "\n",
    "    The MNIST model's Conv2d -> Flatten -> Linear stack needs (batch, channel,\n",
    "    height, width) input, so both a batch axis and a channel axis are added.\n",
    "    Adding only one axis gives (1, 28, 28), which the network treats as a single\n",
    "    unbatched sample whose flattened size no longer matches the first Linear\n",
    "    layer.\n",
    "    \"\"\"\n",
    "    assert arr.shape == (28, 28), f\"expected a 28x28 image, got shape {arr.shape}\"\n",
    "    return np.expand_dims(arr, (0, 1)).astype(\"float32\", copy=False)\n",
    "\n",
    "\n",
    "@svc.api(\n",
    "    input=NumpyNdarray(dtype=\"float32\", enforce_dtype=True),\n",
    "    output=NumpyNdarray(dtype=\"int64\"),\n",
    ")\n",
    "async def predict_ndarray(\n",
    "    inp: \"np.ndarray[t.Any, np.dtype[t.Any]]\",\n",
    ") -> \"np.ndarray[t.Any, np.dtype[t.Any]]\":\n",
    "    \"\"\"Predict the digit for a 28x28 float32 array.\"\"\"\n",
    "    output_tensor = await mnist_runner.async_run(_to_model_batch(inp))\n",
    "    return output_tensor.numpy()\n",
    "\n",
    "\n",
    "@svc.api(input=Image(), output=NumpyNdarray(dtype=\"int64\"))\n",
    "async def predict_image(f: PILImage) -> \"np.ndarray[t.Any, np.dtype[t.Any]]\":\n",
    "    \"\"\"Predict the digit for a 28x28 greyscale image.\"\"\"\n",
    "    assert isinstance(f, PILImage)\n",
    "    # Scale pixel values to [0, 1] like transforms.ToTensor() did during\n",
    "    # training (assumes 8-bit greyscale pixels).\n",
    "    arr = np.array(f) / 255.0\n",
    "    output_tensor = await mnist_runner.async_run(_to_model_batch(arr))\n",
    "    return output_tensor.numpy()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "590147aa",
   "metadata": {},
   "source": [
    "Start a development model server to test the service defined above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29173871",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runs the dev server in the foreground; interrupt the kernel to stop it.\n",
    "!bentoml serve service.py:svc"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "606c1b36",
   "metadata": {},
   "source": [
    "Now you can use something like:\n",
    "\n",
    "`curl -H \"Content-Type: multipart/form-data\" -F'fileobj=@samples/1.png;type=image/png' http://127.0.0.1:3000/predict_image`\n",
    "    \n",
    "to send an image to the digit recognition service."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7f03564",
   "metadata": {},
   "source": [
    "## Build a Bento for distribution and deployment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "207561bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Package service.py (and the other *.py files) into a Bento, excluding tests.\n",
    "bentoml.build(\n",
    "    \"service.py:svc\",\n",
    "    description=\"file:./README.md\",\n",
    "    include=[\"*.py\"],\n",
    "    exclude=[\"tests/\"],\n",
    "    python={\"packages\": [\"scikit-learn\", \"torch\", \"Pillow\"]},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36306933",
   "metadata": {},
   "source": [
    "Start a development server from the built Bento:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec4b9dff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serve the most recently built pytorch_mnist_demo Bento; interrupt the kernel to stop it.\n",
    "!bentoml serve pytorch_mnist_demo:latest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f05fae93",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  },
  "name": "pytorch_mnist.ipynb"
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
